import logging

import torch
import triton
from triton import language as tl

from flag_gems.utils import triton_lang_extension as tle
from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
from flag_gems.utils.shape_utils import c_contiguous_stride
from flag_gems.utils.tensor_wrapper import StridedBuffer

logger = logging.getLogger(__name__)


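# Identity function; pointwise_dynamic turns it into an element-wise copy kernel
# used below to materialize a broadcast strided view into a dense output buffer.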
@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    return x


def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
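    """Equivalent of ``torch.repeat_interleave(inp, repeats)`` with an integer
    ``repeats``: every element along ``dim`` is repeated ``repeats`` times.
    When ``dim`` is None the input is flattened first.
    """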
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_SELF_INT")
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )
    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats < 0:
        raise RuntimeError("repeats cannot be negative")

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    if repeats == 0:
        return output

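    # Insert a length-`repeats` axis right after `dim`: the input view uses
    # stride 0 on that axis (broadcast) while the output view is C-contiguous,
    # so a plain element-wise copy expands each input element `repeats` times.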
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output


@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
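    # One program per element of `repeats`: program `pid` writes its own index
    # into out[cumsum[pid] - repeats[pid] : cumsum[pid]), chunked by BLOCK_SIZE.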
    pid = tle.program_id(0)
    mask = pid < size
    cumsum = tl.load(cumsum_ptr + pid, mask, other=0)
    repeats = tl.load(repeats_ptr + pid, mask, other=0)
    out_offset = cumsum - repeats

    tl.device_assert(repeats >= 0, "repeats cannot be negative")

    out_ptr += out_offset
    for start_k in range(0, repeats, BLOCK_SIZE):
        offsets_k = start_k + tl.arange(0, BLOCK_SIZE)
        mask_k = offsets_k < repeats
        tl.store(out_ptr + offsets_k, pid, mask=mask_k)


@triton.jit
def fused_repeat_and_index_select_kernel(
    inp, out, M, N, cumsum_ptr, BLOCK_SIZE: tl.constexpr
):
    # One program per input row. `cumsum_ptr` holds the inclusive cumulative sum
    # of the repeat counts, so input row `pid` is copied to output rows
    # [cumsum[pid - 1], cumsum[pid]); columns are processed in BLOCK_SIZE chunks.
    pid = tle.program_id(0)
    # For pid == 0 the "previous" cumulative sum is 0 (masked load).
    row_idx_mask = pid > 0
    start_row_idx = tl.load(cumsum_ptr + pid - 1, mask=row_idx_mask, other=0)
    end_row_idx = tl.load(cumsum_ptr + pid)

    num_of_rows = end_row_idx - start_row_idx
    if num_of_rows == 0:
        return

    inp_row_offset = pid * M
    for m in range(0, tl.cdiv(M, BLOCK_SIZE)):
        cols_offsets = m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols_offsets < M
        cur_inp = tl.load(inp + inp_row_offset + cols_offsets, mask=mask, other=0.0)

        for cur_row_in_pid in range(0, num_of_rows):
            output_row_index = start_row_idx + cur_row_in_pid
            output_row_offsets = output_row_index * M + cols_offsets
            tl.store(out + output_row_offsets, cur_inp, mask=mask)


def repeat_interleave_tensor(repeats, *, output_size=None):
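    """Build the index vector produced by ``torch.repeat_interleave(repeats)``:
    index ``p`` appears ``repeats[p]`` times, in order.
    """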
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accepts a 1D tensor of repeats"

    cumsum = repeats.cumsum(axis=0)
    if output_size is None:
        # Synchronize once to read the total number of indices to produce.
        result_size = cumsum[-1].item()
    else:
        # output_size, when provided, is the precomputed sum of repeats and lets
        # us skip the device-to-host synchronization above.
        result_size = output_size

    assert result_size >= 0, "repeats cannot be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)
    size = repeats.size(0)

    grid = (size,)
    BLOCK_SIZE = 32
    repeat_interleave_tensor_kernel[grid](
        repeats,
        cumsum,
        out,
        size,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=1,
    )
    return out


def fused_repeat_and_index_select(inp, repeats, dim):
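    """Fused path for 2D inputs repeated along dim 0: instead of materializing an
    index vector and calling index_select, each input row is written directly to
    its destination rows in the output.
    """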
    logger.debug("METAX GEMS FUSED_REPEAT_AND_INDEX_SELECT")

    assert repeats.ndim == 1, "repeat_interleave only accepts a 1D tensor of repeats"

    cumsum = repeats.cumsum(axis=0)
    index_len = cumsum[-1].item()

    out_shape = list(inp.shape)
    out_shape[dim] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    N = inp.shape[dim]
    M = inp.numel() // N

    grid = (inp.shape[dim],)
    BLOCK_SIZE = min(triton.next_power_of_2(M), 4096)

    fused_repeat_and_index_select_kernel[grid](
        inp, out, M, N, cumsum, BLOCK_SIZE=BLOCK_SIZE
    )
    return out


def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
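    """Equivalent of ``torch.repeat_interleave(inp, repeats)`` with a tensor
    ``repeats``: element ``i`` along ``dim`` is repeated ``repeats[i]`` times.
    Dispatches to the scalar path for 0-d / single-element ``repeats``, to the
    fused kernel for 2D inputs with dim == 0, and otherwise to index_select.
    """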
    logger.debug("METAX GEMS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got \
                repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

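    # Fast path: for a 2D input repeated along dim 0 the gather can be fused;
    # otherwise build the index vector and fall back to torch.index_select.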
    if inp.ndim == 2 and dim == 0:
        res = fused_repeat_and_index_select(inp, repeats, dim)
    else:
        indices = repeat_interleave_tensor(repeats, output_size=output_size)
        res = torch.index_select(inp, dim, indices)

    return res
